import math, torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .activation import *

def initilize(layer):
    """Re-initialise the weights (and Linear biases) of ``layer`` in place.

    Meant to be passed to ``nn.Module.apply``, which calls it on every
    submodule. Modules that are neither ``nn.Conv2d`` nor ``nn.Linear`` are
    left untouched.

    * ``nn.Conv2d``: weight gets Kaiming-normal init (``fan_in`` mode, ReLU
      gain). The bias, if any, keeps whatever values it already has.
    * ``nn.Linear``: weight gets Kaiming-normal init (``fan_in`` mode, ReLU
      gain). The bias, if any, is drawn from
      ``U(-1/sqrt(in_features), 1/sqrt(in_features))``.

    Args:
        layer: any ``nn.Module``.
    """
    if isinstance(layer, nn.Conv2d):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
    elif isinstance(layer, nn.Linear):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        if layer.bias is not None:
            fan_in = layer.in_features
            # A zero-width layer would otherwise divide by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            # uniform_(a, b) samples from [a, b]. normal_ would interpret
            # these two arguments as (mean, std) instead of a range.
            nn.init.uniform_(layer.bias, -bound, bound)

class LinearNet(nn.Module):
    """Fully connected classifier: ``[Linear -> (BatchNorm1d) -> act] * k -> Linear``.

    Args:
        cfg: config object. Reads
            ``cfg.Classifier.hidden_dims`` (iterable of hidden widths; may be
            empty, giving a single linear layer),
            ``cfg.Classifier.normalization`` (truthy inserts
            ``nn.BatchNorm1d`` after each hidden Linear), and
            ``cfg.Classifier.activation_name`` (``"ModifiedRelu"`` selects
            ``ModifiedRelu``; any other value selects ``nn.ReLU``).
        **kwargs: must contain ``input_shape`` (per-sample shape, flattened to
            ``prod(input_shape)`` features) and ``num_classes`` (output width).
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__()
        # int(): np.prod returns a NumPy scalar, and the float 1.0 for an
        # empty shape. Layer sizes and reshape want a plain int.
        self.input_dim = int(np.prod(kwargs["input_shape"]))
        self.hidden_dims = cfg.Classifier.hidden_dims
        self.output_dim = kwargs["num_classes"]
        self.normalization = cfg.Classifier.normalization

        self.activation_name = cfg.Classifier.activation_name
        self.activation = ModifiedRelu if self.activation_name == "ModifiedRelu" else nn.ReLU

        # Unpacking accepts any iterable of widths (list, tuple, config sequence).
        widths = [self.input_dim, *self.hidden_dims]

        self.model = nn.Sequential()
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            self.model.append(nn.Linear(in_dim, out_dim))
            if self.normalization:
                self.model.append(nn.BatchNorm1d(out_dim))
            self.model.append(self.activation())
        # widths[-1] is the last hidden width, or input_dim when there are no
        # hidden layers.
        self.model.append(nn.Linear(widths[-1], self.output_dim))

        self.model.apply(initilize)

    def forward(self, x):
        """Flatten each sample to ``input_dim`` features and return the logits."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(x.shape[0], self.input_dim)
        return self.model(x)

class ConvNet(nn.Module):
    """LeNet-style CNN: two (conv -> act -> 2x2 max-pool) stages, then three Linear layers.

    ``fc1`` expects ``16 * 5 * 5`` flattened features, so the input's spatial
    size must reduce to a 5x5 map after both stages (e.g. 32x32 images).

    Args:
        cfg: config object. Reads ``cfg.Classifier.activation_name``
            (``"ModifiedRelu"`` selects ``ModifiedRelu``; any other value
            selects ``nn.ReLU``).
        **kwargs: must contain ``num_classes`` (output width) and
            ``num_channels`` (input channels).
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__()
        self.output_dim = kwargs["num_classes"]
        self.input_dim = kwargs["num_channels"]
        self.conv1 = nn.Conv2d(self.input_dim, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, self.output_dim)

        for layer in (self.conv1, self.conv2, self.fc1, self.fc2, self.fc3):
            layer.apply(initilize)

        self.activation_name = cfg.Classifier.activation_name
        self.activation = ModifiedRelu if self.activation_name == "ModifiedRelu" else nn.ReLU

        # Build each activation once, so it is a registered submodule that
        # moves with .to(device), rather than a new module on every forward.
        # NOTE(review): if ModifiedRelu has parameters, they are now visible to
        # .parameters() and will be trained. Confirm this is intended.
        self.act_conv1 = self.activation()
        self.act_conv2 = self.activation()
        self.act_fc1 = self.activation()
        self.act_fc2 = self.activation()

    def forward(self, x):
        """Return class logits for an image batch ``x`` (N, C, H, W)."""
        x = self.pool(self.act_conv1(self.conv1(x)))
        x = self.pool(self.act_conv2(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.act_fc1(self.fc1(x))
        x = self.act_fc2(self.fc2(x))
        return self.fc3(x)

class ResNet(nn.Module):
    """ResNet-18 from torch.hub with a frozen backbone and a new trainable linear head.

    Args:
        cfg: config object. Reads ``cfg.Classifier.pretrain``, which is
            forwarded as ``pretrained`` to the hub ``resnet18`` entry point.
        **kwargs: must contain ``num_classes`` (output width of the new head).

    Constructing this module calls ``torch.hub.load``, which may need network
    access when the hub repo or weights are not already cached.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__()
        pretrain = cfg.Classifier.pretrain
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=pretrain)
        self.output_dim = kwargs["num_classes"]
        # Freeze every backbone parameter...
        for param in self.model.parameters():
            param.requires_grad = False
        # ...then swap in a new head. Its width is read from the existing head
        # instead of hard-coding 512. The head keeps PyTorch's default Linear
        # initialisation (custom init is intentionally not applied here).
        self.model.fc = nn.Linear(self.model.fc.in_features, self.output_dim)
        # Fresh parameters already require grad; set explicitly for clarity.
        for param in self.model.fc.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Return class logits computed by the wrapped ResNet-18."""
        return self.model(x)